# 3-D scatter of the three columns of `int_out` (labelled Neuron 1-3), with one
# colour per class of `y_test`.
# NOTE(review): `int_out` is presumably an (n_samples, >=3) array of
# intermediate-layer outputs and `y_test` an integer label array aligned with
# its rows, with classes 0-9 -- confirm against the code that builds them.
fig = plt.figure(figsize=(7, 7))
# Calling Axes3D(fig, ...) directly no longer attaches the axes to the figure
# (deprecated in Matplotlib 3.4, removed in 3.6), which leaves the plot empty.
# add_subplot registers the axes with the figure; view_init sets the camera.
axs = fig.add_subplot(111, projection='3d')
axs.view_init(elev=10, azim=-70)

for i in range(10):
    # Build the class mask once and reuse it for all three coordinates.
    mask = y_test == i
    axs.scatter(int_out[mask, 0], int_out[mask, 1], int_out[mask, 2],
                marker='.', label=str(i))
axs.set_xlabel('Neuron 1', fontsize=18)
axs.set_ylabel('Neuron 2', fontsize=18)
axs.set_zlabel('Neuron 3', fontsize=18)
# Without a legend the per-class colours cannot be told apart.
axs.legend(title='Class')
axs.grid(True)
plt.show()
